import numpy as np
import pandas as pd
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
import pickle
# Initialize the OpenAI client. With no api_key argument the client reads the
# key from the OPENAI_API_KEY environment variable, so no secret has to be
# pasted into this file.
client = OpenAI()

# True: reuse the embeddings cached in data/cleaned/transcript_embeddings.pkl.
# False: embed every quote through the API and (re)write that cache.
load_from_pickle = True

def get_embedding(text, *, model="text-embedding-3-small"):
    """Return the embedding vector for a single piece of text.

    Calls the OpenAI embeddings endpoint through the module-level ``client``.

    Args:
        text: The string to embed.
        model: Embedding model name (keyword-only). Defaults to
            ``"text-embedding-3-small"``, the model the rest of this script
            was written against.

    Returns:
        The embedding of ``text`` (``response.data[0].embedding``), an
        iterable of numbers.
    """
    response = client.embeddings.create(model=model, input=text)
    return response.data[0].embedding

def find_representative_quotes(quotes, num_representatives=5):
    """Rank quotes by cosine similarity to the mean of their embeddings.

    Every quote is embedded with ``get_embedding``; the quotes whose vectors
    lie closest to the centroid of the whole set rank first.

    Args:
        quotes: Iterable of quote strings.
        num_representatives: Currently ignored -- the full ranking is always
            returned. It is kept so existing callers keep working; take the
            top N with ``result.head(n)``.

    Returns:
        DataFrame with columns ``quote`` and ``similarity``, sorted by
        ``similarity`` in descending order and re-indexed 0..n-1. An empty
        input yields an empty DataFrame with the same columns and makes no
        API calls.
    """
    # Materialize once: accepts generators/Series and makes the emptiness
    # test unambiguous.
    quotes = list(quotes)
    if not quotes:
        # The mean of zero embeddings is undefined (NaN) and sklearn would
        # reject it, so short-circuit with an empty, correctly-shaped result.
        return pd.DataFrame({
            'quote': pd.Series([], dtype=object),
            'similarity': pd.Series([], dtype=float),
        })

    # One row per quote, one column per embedding dimension.
    embeddings = np.asarray([get_embedding(quote) for quote in quotes])

    # Centroid of all quote embeddings.
    mean_embedding = embeddings.mean(axis=0)

    # Cosine similarity between the centroid and each quote embedding.
    similarities = cosine_similarity([mean_embedding], embeddings)[0]

    df = pd.DataFrame({
        'quote': quotes,
        'similarity': similarities
    })

    # Most representative (highest similarity) first, with a clean 0..n-1 index.
    return df.sort_values('similarity', ascending=False).reset_index(drop=True)


# Example usage: rank a handful of well-known quotes by how close each one
# sits to the centroid of the whole set.
quotes = [
    "To be or not to be, that is the question.",
    "All the world's a stage, and all the men and women merely players.",
    "The lady doth protest too much, methinks.",
    "Romeo, Romeo, wherefore art thou Romeo?",
    "Now is the winter of our discontent.",
    "Friends, Romans, countrymen, lend me your ears.",
    "The course of true love never did run smooth.",
    "If music be the food of love, play on.",
    "What's in a name? That which we call a rose by any other name would smell as sweet.",
    "Some are born great, some achieve greatness, and some have greatness thrust upon them."
]

result_df = find_representative_quotes(quotes)

print("Quotes ranked by similarity to mean embedding:")
print(result_df)

# The ranking is already sorted, so the first N rows are the N most
# representative quotes.
num_representatives = 5
top_quotes = result_df.head(num_representatives)

print(f"\nTop {num_representatives} most representative quotes:")
ranked_pairs = zip(top_quotes['quote'], top_quotes['similarity'])
for position, (quote_text, score) in enumerate(ranked_pairs, start=1):
    print(f"{position}. {quote_text} (Similarity: {score:.4f})")



def find_uniquely_representative_quote(df):
    """Score each group-1 quote by how much closer it is to its own group's centroid.

    Rows whose ``pair_includes_trans`` equals 1 form group 1, and rows equal
    to 0 form group 0. Rows with any other value are ignored. Every quote in
    both groups is embedded with ``get_embedding``, and each group's mean
    embedding is computed. For each group-1 quote the score is::

        cos(quote, mean_1) - cos(quote, mean_0)

    Higher values therefore mark quotes that are more typical of group 1 than
    of group 0.

    Args:
        df: DataFrame with a ``quote`` text column and a ``pair_includes_trans``
            column of 0/1 labels.

    Returns:
        A copy of the group-1 rows of ``df`` (original index preserved) with an
        added ``similarity_diff`` column.

    Raises:
        ValueError: if either group is empty, since a group mean cannot be
            computed from zero embeddings.
    """
    # Build each mask once. Reusing it for the quote list and for the returned
    # frame guarantees that their row order matches.
    mask_1 = df['pair_includes_trans'] == 1
    mask_0 = df['pair_includes_trans'] == 0
    group_1 = df.loc[mask_1, 'quote'].tolist()
    group_0 = df.loc[mask_0, 'quote'].tolist()

    # Fail fast, before spending any embedding API calls. The mean of an empty
    # group is NaN, and sklearn would reject it later with an obscure error.
    if not group_1 or not group_0:
        raise ValueError(
            "find_uniquely_representative_quote needs at least one quote with "
            f"pair_includes_trans == 1 (got {len(group_1)}) and at least one "
            f"with pair_includes_trans == 0 (got {len(group_0)})"
        )

    # Progress markers: each embedding is a separate API call.
    print("embeddings_1")
    embeddings_1 = np.asarray([get_embedding(quote) for quote in group_1])
    print("embeddings_0")
    embeddings_0 = np.asarray([get_embedding(quote) for quote in group_0])

    # Centroid of each group
    mean_embedding_1 = embeddings_1.mean(axis=0)
    mean_embedding_0 = embeddings_0.mean(axis=0)

    # Similarity of every group-1 quote to both centroids
    similarities_1 = cosine_similarity(embeddings_1, [mean_embedding_1]).flatten()
    similarities_0 = cosine_similarity(embeddings_1, [mean_embedding_0]).flatten()

    df_group_1 = df.loc[mask_1].copy()
    df_group_1['similarity_diff'] = similarities_1 - similarities_0

    return df_group_1


# Example usage: a toy frame whose binary label alternates 1, 0, 1, 0, ...
# so that each group holds five quotes.
df = pd.DataFrame(
    {
        'quote': [
            "To be or not to be, that is the question.",
            "All the world's a stage, and all the men and women merely players.",
            "The lady doth protest too much, methinks.",
            "Romeo, Romeo, wherefore art thou Romeo?",
            "Now is the winter of our discontent.",
            "Friends, Romans, countrymen, lend me your ears.",
            "The course of true love never did run smooth.",
            "If music be the food of love, play on.",
            "What's in a name? That which we call a rose by any other name would smell as sweet.",
            "Some are born great, some achieve greatness, and some have greatness thrust upon them."
        ],
        'pair_includes_trans': [1, 0] * 5,  # example binary classification
    }
)

# NOTE: the returned frame is not used further in this script. This call only
# demonstrates the function, and it embeds every quote above to do so.
df_similarity = find_uniquely_representative_quote(df)


# Load the cleaned transcripts and keep only the rows that can be embedded.
raw_transcripts = pd.read_csv('data/cleaned/transcripts_clean.csv')

# Utterances spoken by 'Lead' or 'Other' are excluded.
is_excluded_speaker = raw_transcripts['who_speaking'].isin(['Lead', 'Other'])

# The text column is renamed to 'quote' (the name used by the rest of the
# script). Rows missing either the group label or the text are then dropped.
quotes_df = (
    raw_transcripts[~is_excluded_speaker]
    .rename(columns={'speech_english': 'quote'})
    .dropna(subset=['pair_includes_trans', 'quote'])
)

# Only quotes_df is needed from here on.
del raw_transcripts, is_excluded_speaker

# Get an embedding for every quote, either freshly from the API or from the
# pickle cache written by a previous run. Each vector is also stored in the
# 'embedding' column as a ';'-joined string (one value per row in the CSV).
if not load_from_pickle:

    # Work through the quotes in chunks of 100 so progress can be reported.
    # Note that each quote is still embedded with its own get_embedding call.
    quote_embeddings = []
    for i in range(0, len(quotes_df), 100):

        # Print progress
        print(f"Processing quotes {i} to {i+100} of {len(quotes_df)}")

        quotes_subset = quotes_df['quote'].iloc[i:i+100].tolist()
        quote_embeddings.extend(get_embedding(quote) for quote in quotes_subset)

    # Vector k belongs to row k of quotes_df.
    quotes_df['embedding'] = [
        ';'.join(str(number) for number in embed_list)
        for embed_list in quote_embeddings
    ]

    # Cache the raw vectors, plus the rows they were computed for.
    with open('data/cleaned/transcript_embeddings.pkl', 'wb') as file:
        pickle.dump(quote_embeddings, file)

    with open('data/cleaned/transcript_embeddings_ids.pkl', 'wb') as file:
        pickle.dump(quotes_df, file)

else:

    # LOAD FROM PICKLE FILE. pickle can execute arbitrary code while loading,
    # so only point this at a cache that this script wrote itself.
    with open('data/cleaned/transcript_embeddings.pkl', 'rb') as file:
        quote_embeddings = pickle.load(file)

    # The cache is positional: vector k belongs to row k of quotes_df as it was
    # when the cache was written. A length mismatch means the CSV or the
    # filtering changed since then, so fail loudly here rather than with an
    # opaque IndexError later. Equal lengths do not prove that the rows still
    # match, so re-run with load_from_pickle = False after changing the input.
    if len(quote_embeddings) != len(quotes_df):
        raise ValueError(
            f"Cached embeddings ({len(quote_embeddings)}) do not match the number "
            f"of quotes ({len(quotes_df)}); regenerate the cache by setting "
            "load_from_pickle = False."
        )

    # Same ';'-joined string format as in the branch above.
    quotes_df['embedding'] = [
        ';'.join(str(number) for number in embed_list)
        for embed_list in quote_embeddings
    ]


# Compare every quote with the centroid (mean embedding) of each
# pair_includes_trans group, then export the results.
embedding_matrix = np.asarray(quote_embeddings)  # one row per quote
group_labels = quotes_df['pair_includes_trans'].to_numpy()

# Embeddings are matched to rows purely by position.
if len(embedding_matrix) != len(group_labels):
    raise ValueError(
        f"Got {len(embedding_matrix)} embeddings for {len(group_labels)} quotes; "
        "they must be aligned row by row."
    )

# Embeddings of each group, selected with boolean masks over the labels
quote_embeddings_0 = embedding_matrix[group_labels == 0]
quote_embeddings_1 = embedding_matrix[group_labels == 1]

# A group mean over zero rows is NaN, and cosine_similarity would reject it.
if len(quote_embeddings_0) == 0 or len(quote_embeddings_1) == 0:
    raise ValueError(
        "Both pair_includes_trans groups need at least one quote "
        f"(group 0: {len(quote_embeddings_0)}, group 1: {len(quote_embeddings_1)})."
    )

# Get mean embeddings for each group
mean_embedding_0 = quote_embeddings_0.mean(axis=0)
mean_embedding_1 = quote_embeddings_1.mean(axis=0)

# Similarity of EVERY quote (from both groups) to each group mean. A positive
# difference means the quote is closer to the group-1 centroid.
similarities_1 = cosine_similarity(embedding_matrix, [mean_embedding_1]).flatten()
similarities_0 = cosine_similarity(embedding_matrix, [mean_embedding_0]).flatten()
similarity_diff = similarities_1 - similarities_0

quotes_df['similarity_diff'] = similarity_diff
quotes_df['similarities_1'] = similarities_1
# Each group mean is written as a single ';'-joined string, repeated on every row.
quotes_df['mean_embedding_0'] = ';'.join(str(number) for number in mean_embedding_0)
quotes_df['mean_embedding_1'] = ';'.join(str(number) for number in mean_embedding_1)

# Export quotes_df to csv
quotes_df.to_csv("data/cleaned/transcripts_with_embeddings.csv", index=False)
